#!/usr/bin/env python
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Renderer DAG (tasks and dependencies) to the graphviz object."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

try:
    import graphviz
except ImportError:
    warnings.warn(
        "Could not import graphviz. Rendering graph to the graphical format will not be possible.",
        UserWarning,
        stacklevel=2,
    )
    graphviz = None

from airflow.exceptions import AirflowException
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
from airflow.utils.dag_edges import dag_edges
from airflow.utils.state import State
from airflow.utils.task_group import TaskGroup

if TYPE_CHECKING:
    from airflow.models import TaskInstance
    from airflow.models.dag import DAG
    from airflow.models.taskmixin import DependencyMixin
    from airflow.serialization.serialized_objects import DagDependency


def _refine_color(color: str):
    """
    Convert color in #RGB (12 bits) format to #RRGGBB (32 bits), if it possible.

    Otherwise, it returns the original value. Graphviz does not support colors in #RGB format.

    :param color: Text representation of color
    :return: Refined representation of color
    """
    if len(color) == 4 and color[0] == "#":
        color_r = color[1]
        color_g = color[2]
        color_b = color[3]
        return "#" + color_r + color_r + color_g + color_g + color_b + color_b
    return color


def _draw_task(
    task: MappedOperator | BaseOperator,
    parent_graph: graphviz.Digraph,
    states_by_task_id: dict[Any, Any] | None,
) -> None:
    """
    Add a single rounded-rectangle node for ``task`` to ``parent_graph``.

    :param task: Operator to draw; its ``task_id`` becomes the node id and ``label`` the caption.
    :param parent_graph: Graph (or cluster subgraph) that receives the node.
    :param states_by_task_id: Mapping of task_id to state. When it is non-empty, the node
        is colored via ``State.color_fg`` / ``State.color`` for the task's state (``None``
        if the task is missing from the mapping). When it is ``None`` or empty, the
        operator's own ``ui_fgcolor`` / ``ui_color`` are used instead.
    """
    if not states_by_task_id:
        # No run state available: fall back to the operator's configured UI colors.
        border_color, background_color = task.ui_fgcolor, task.ui_color
    else:
        task_state = states_by_task_id.get(task.task_id)
        border_color, background_color = State.color_fg(task_state), State.color(task_state)

    node_attributes = {
        "label": task.label,
        "shape": "rectangle",
        "style": "filled,rounded",
        "color": _refine_color(border_color),
        "fillcolor": _refine_color(background_color),
    }
    parent_graph.node(task.task_id, _attributes=node_attributes)


def _draw_task_group(
    task_group: TaskGroup, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
) -> None:
    """
    Draw the join markers of ``task_group`` and then, recursively, all of its children.

    A small circular "join" node is added for the upstream side only when the group has
    upstream groups or tasks. Likewise, one is added for the downstream side only when it
    has downstream groups or tasks. Children are drawn in order of ``node_id``, with a
    missing/empty id sorting as ``""``.

    :param task_group: The TaskGroup whose contents should be drawn.
    :param parent_graph: Graph (or cluster subgraph) that receives the nodes.
    :param states_by_task_id: Optional mapping of task_id to state, forwarded to child tasks.
    """

    def add_join_node(join_id: str) -> None:
        # Both join markers share the same look, derived from the group's UI colors.
        parent_graph.node(
            join_id,
            _attributes={
                "label": "",
                "shape": "circle",
                "style": "filled,rounded",
                "color": _refine_color(task_group.ui_fgcolor),
                "fillcolor": _refine_color(task_group.ui_color),
                "width": "0.2",
                "height": "0.2",
            },
        )

    if task_group.upstream_group_ids or task_group.upstream_task_ids:
        add_join_node(task_group.upstream_join_id)
    if task_group.downstream_group_ids or task_group.downstream_task_ids:
        add_join_node(task_group.downstream_join_id)

    ordered_children = sorted(task_group.children.values(), key=lambda child: child.node_id or "")
    for child in ordered_children:
        _draw_nodes(child, parent_graph, states_by_task_id)


def _draw_nodes(
    node: DependencyMixin, parent_graph: graphviz.Digraph, states_by_task_id: dict[str, str] | None
) -> None:
    """
    Recursively draw ``node`` (an operator or a TaskGroup) onto ``parent_graph``.

    Operators become single task nodes. The root TaskGroup's contents are drawn directly
    into ``parent_graph``. Any other TaskGroup is wrapped in a ``cluster_<group_id>``
    subgraph, which Graphviz renders as a labelled, filled box.

    :param node: Operator or TaskGroup to draw.
    :param parent_graph: Graph (or cluster subgraph) that receives the nodes.
    :param states_by_task_id: Optional mapping of task_id to state, forwarded to tasks.
    :raises AirflowException: If ``node`` is neither an operator nor a TaskGroup.
    """
    if isinstance(node, (BaseOperator, MappedOperator)):
        _draw_task(node, parent_graph, states_by_task_id)
        return

    if not isinstance(node, TaskGroup):
        raise AirflowException(f"The node {node} should be TaskGroup and is not")

    if node.is_root:
        # The root group spans the whole DAG, so it gets no background box.
        _draw_task_group(node, parent_graph, states_by_task_id)
        return

    with parent_graph.subgraph(name=f"cluster_{node.group_id}") as cluster:
        cluster.attr(
            shape="rectangle",
            style="filled",
            color=_refine_color(node.ui_fgcolor),
            # CornflowerBlue with ~50% alpha (0x7f)
            fillcolor="#6495ed7f",
            label=node.label,
        )
        _draw_task_group(node, cluster, states_by_task_id)


def render_dag_dependencies(deps: dict[str, list[DagDependency]]) -> graphviz.Digraph:
    """
    Render the DAG dependency to the DOT object.

    Each DAG with at least one dependency gets a single subgraph named after its DAG id.
    For every dependency, that subgraph holds an edge ``source -> dependency_id`` followed
    by an edge ``dependency_id -> target``.

    :param deps: Mapping of DAG id to the list of dependencies recorded for that DAG.
    :return: Graphviz object
    :raises AirflowException: If the ``graphviz`` package could not be imported.
    """
    if not graphviz:
        raise AirflowException(
            "Could not import graphviz. Install the graphviz python package to fix this error."
        )
    dot = graphviz.Digraph(graph_attr={"rankdir": "LR"})

    for dag_id, dependencies in deps.items():
        if not dependencies:
            # Nothing to draw; don't emit an empty subgraph for this DAG.
            continue
        # Open the subgraph once per DAG instead of once per dependency. Re-opening it
        # for every dependency re-emitted an identical ``subgraph <dag_id> {...}`` block
        # (attributes included) for each edge pair.
        with dot.subgraph(
            name=dag_id,
            graph_attr={
                "rankdir": "LR",
                "labelloc": "t",
                "label": dag_id,
            },
        ) as dep_subgraph:
            for dep in dependencies:
                dep_subgraph.edge(dep.source, dep.dependency_id)
                dep_subgraph.edge(dep.dependency_id, dep.target)

    return dot


def render_dag(dag: DAG, tis: list[TaskInstance] | None = None) -> graphviz.Digraph:
    """
    Build a graphviz ``Digraph`` describing the tasks and dependencies of ``dag``.

    The graph is named and labelled with the DAG id, and laid out in ``dag.orientation``
    (``"LR"`` when that is empty). When a task instance list is passed, task nodes are
    colored by the state of the matching task instance. Otherwise each operator's own
    UI colors are used. Edges carry the ``label`` from the DAG's edge info, when one
    is set.

    :param dag: DAG that will be rendered.
    :param tis: List of task instances
    :return: Graphviz object
    :raises AirflowException: If the ``graphviz`` package could not be imported.
    """
    if not graphviz:
        raise AirflowException(
            "Could not import graphviz. Install the graphviz python package to fix this error."
        )

    graph_attributes = {
        "rankdir": dag.orientation or "LR",
        "labelloc": "t",
        "label": dag.dag_id,
    }
    dot = graphviz.Digraph(dag.dag_id, graph_attr=graph_attributes)

    states_by_task_id = None if tis is None else {ti.task_id: ti.state for ti in tis}
    _draw_nodes(dag.task_group, dot, states_by_task_id)

    for edge in dag_edges(dag):
        source_id, target_id = edge["source_id"], edge["target_id"]
        # ``get("label")`` yields None when no label was set, which graphviz treats as
        # "no label", so it can be passed straight through.
        edge_label = dag.get_edge_info(source_id, target_id).get("label")
        dot.edge(source_id, target_id, edge_label)

    return dot
